{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /data\\train-images-idx3-ubyte.gz\n",
      "Extracting /data\\train-labels-idx1-ubyte.gz\n",
      "Extracting /data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting /data\\t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# 输入数据\n",
    "from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets\n",
    "mnist=read_data_sets('/data',one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义网络的超参数\n",
    "learning_rate=0.01 #学习率\n",
    "training_iters=20000 #训练的数据量\n",
    "batch_size=128 #每次训练多少数据\n",
    "display_step=10  #每多少次显示一下当前状态"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义网络的结构参数\n",
    "n_input=784\n",
    "n_classes=10\n",
    "dropout=0.75"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#设定数据占位符\n",
    "x=tf.placeholder(tf.float32,[None,n_input])\n",
    "y=tf.placeholder(tf.float32,[None,n_classes])\n",
    "keep_prob=tf.placeholder(tf.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建网络模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义卷积操作（Conv layer）\n",
    "def conv2d(x,W,b,strides=1):\n",
    "    x=tf.nn.conv2d(x,W,strides=[1,strides,strides,1],padding='SAME')\n",
    "    x=tf.nn.bias_add(x,b)\n",
    "    return tf.nn.relu(x)\n",
    "# 定义池化操作\n",
    "def maxpool2d(x,k=2):\n",
    "    return tf.nn.max_pool(x,ksize=[1,k,k,1],strides=[1,k,k,1],padding='SAME')\n",
    "\n",
    "#局部归一化\n",
    "def norm(x,lsize=4):\n",
    "    return tf.nn.lrn(pool1,lsize,bias=1.0,alpha=0.001/9.0,beta=0.75)\n",
    "\n",
    "# 定义网络的权重和偏置参数\n",
    "weights={\n",
    "    'wc1':tf.Variable(tf.random_normal([11,11,1,96])),\n",
    "    'wc2':tf.Variable(tf.random_normal([5,5,96,256])),\n",
    "    'wc3':tf.Variable(tf.random_normal([3,3,256,384])),\n",
    "    'wc4':tf.Variable(tf.random_normal([3,3,384,384])),\n",
    "    'wc5':tf.Variable(tf.random_normal([3,3,384,256])),\n",
    "    'wd1':tf.Variable(tf.random_normal([2*2*256,4096])),\n",
    "    'wd2':tf.Variable(tf.random_normal([4096,4096])),\n",
    "    'out':tf.Variable(tf.random_normal([4096,n_classes]))\n",
    "}\n",
    "biases={\n",
    "    'bc1':tf.Variable(tf.random_normal([96])),\n",
    "    'bc2':tf.Variable(tf.random_normal([256])),\n",
    "    'bc3':tf.Variable(tf.random_normal([384])),\n",
    "    'bc4':tf.Variable(tf.random_normal([384])),\n",
    "    'bc5':tf.Variable(tf.random_normal([256])),\n",
    "    'bd1':tf.Variable(tf.random_normal([4096])),\n",
    "    'bd2':tf.Variable(tf.random_normal([4096])),\n",
    "    'out':tf.Variable(tf.random_normal([n_classes]))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义Alexnet网络结构\n",
    "def alex_net(x,weights,biases,dropout):\n",
    "    # 输出的数据做reshape\n",
    "    x=tf.reshape(x,shape=[-1,28,28,1])\n",
    "    \n",
    "    #第一层卷积计算（conv+relu+pool）\n",
    "    # 卷积\n",
    "    conv1=conv2d(x,weights['wc1'],biases['bc1'])\n",
    "    # 池化\n",
    "    pool1=maxpool2d(conv1,k=2)\n",
    "    # 规范化，局部归一化\n",
    "    # 局部归一化是仿造生物学上的活跃的神经元对相邻神经元的抑制现象\n",
    "    norm1=norm(pool1)\n",
    "    \n",
    "    #第二层卷积\n",
    "    conv2=conv2d(norm1,weights['wc2'],biases['bc2'])\n",
    "    # 池化\n",
    "    pool2=maxpool2d(conv2,k=2)\n",
    "    norm2=norm(pool2)\n",
    "    \n",
    "    #第三层卷积\n",
    "    conv3=conv2d(norm2,weights['wc3'],biases['bc3'])\n",
    "    # 池化\n",
    "    pool3=maxpool2d(conv3,k=2)\n",
    "    norm3=norm(pool3)\n",
    "    \n",
    "    #第四层卷积\n",
    "    conv4=conv2d(norm3,weights['wc4'],biases['bc4'])\n",
    "    #第五层卷积\n",
    "    conv5=conv2d(conv4,weights['wc5'],biases['bc5'])\n",
    "    # 池化\n",
    "    pool5=maxpool2d(conv5,k=2)\n",
    "    norm5=norm(pool5)\n",
    "    #可以再加上dropout\n",
    "    \n",
    "    #全连接1\n",
    "    # 向量化\n",
    "    fc1=tf.reshape(norm5,[-1,weights['wd1'].get_shape().as_list()[0]])\n",
    "    fc1=tf.add(tf.matmul(fc1,weights['wd1']),biases['bd1'])\n",
    "    fc1=tf.nn.relu(fc1)\n",
    "    #dropout\n",
    "    fc1=tf.nn.dropout(fc1,dropout)\n",
    "    \n",
    "    #全连接2\n",
    "    ## 向量化\n",
    "    fc2=tf.reshape(fc1,[-1,weights['wd2'].get_shape().as_list()[0]])\n",
    "    fc2=tf.add(tf.matmul(fc2,weights['wd2']),biases['bd2'])\n",
    "    fc2=tf.nn.relu(fc2)\n",
    "    #dropout\n",
    "    fc2=tf.nn.dropout(fc2,dropout)\n",
    "    \n",
    "    #out\n",
    "    return tf.add(tf.matmul(fc2,weights['out']),biases['out'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练和评估模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1.定义损失函数和优化器，并构建评估函数\n",
    "# （1）构建模型\n",
    "pred=alex_net(x,weights,biases,keep_prob)\n",
    "# (2)损失函数和优化器\n",
    "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred,labels=y))\n",
    "optim=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
    "#(3)评估函数\n",
    "correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))\n",
    "acc=tf.reduce_mean(tf.cast(correct_pred,tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter:1280,Loss:108323.960938,Train Acc:0.085938\n",
      "Iter:2560,Loss:30001.238281,Train Acc:0.109375\n",
      "Iter:3840,Loss:13807.622070,Train Acc:0.085938\n",
      "Iter:5120,Loss:8310.585938,Train Acc:0.101562\n",
      "Iter:6400,Loss:5273.828125,Train Acc:0.070312\n",
      "Iter:7680,Loss:3571.667969,Train Acc:0.140625\n",
      "Iter:8960,Loss:2524.023193,Train Acc:0.140625\n",
      "Iter:10240,Loss:2586.094971,Train Acc:0.093750\n",
      "Iter:11520,Loss:1627.476562,Train Acc:0.226562\n",
      "Iter:12800,Loss:1126.197021,Train Acc:0.265625\n",
      "Iter:14080,Loss:905.920776,Train Acc:0.187500\n",
      "Iter:15360,Loss:655.510010,Train Acc:0.195312\n",
      "Iter:16640,Loss:394.565369,Train Acc:0.281250\n",
      "Iter:17920,Loss:337.006927,Train Acc:0.296875\n",
      "Iter:19200,Loss:209.208725,Train Acc:0.375000\n",
      "Optimization finished\n"
     ]
    }
   ],
   "source": [
    "# 训练\n",
    "init=tf.global_variables_initializer()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    step=1\n",
    "    # 开始训练，直到达到training_iters\n",
    "    while step*batch_size<training_iters:\n",
    "        batch_x,batch_y=mnist.train.next_batch(batch_size)\n",
    "        sess.run(optim,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})\n",
    "        if step%display_step==0:\n",
    "            # 显示一下当前的损失和正确率\n",
    "            loss,acc_num=sess.run([cost,acc],feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})\n",
    "            print('Iter:%d,Loss:%f,Train Acc:%f'%(step*batch_size,loss,acc_num))\n",
    "        step+=1\n",
    "    print('Optimization finished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    ""
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}